{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "37a2fdbf-1fc9-433f-ac19-b073d95ad154",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 0. Preliminaries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15c2bd16-09cd-4dd6-b4b3-14b424ea4764",
   "metadata": {},
   "source": [
    "Before starting, here are some prerequisites for this tutorial.\n",
    "\n",
    "### 💻 Environment requirements\n",
    "This project was tested with:\n",
    "\n",
    "- Linux OS (Windows Subsystem for Linux _**might**_ work, but we do not officially support it)\n",
    "- 64G RAM\n",
    "- NVIDIA GTX 1080 Ti 11G, NVIDIA V100 32G, NVIDIA A40 48G\n",
    "- CUDA 11.8 and 12.1\n",
    "- conda 23.3.1\n",
    "\n",
    "### 🏗  Installation\n",
    "As indicated in our [README](../README.md), simply run [`install.sh`](../install.sh) to install all dependencies in a new conda environment\n",
    "named `spt`.\n",
    "\n",
    "```bash\n",
    "# Creates a conda env named 'spt' and installs dependencies\n",
    "./install.sh\n",
    "```\n",
    "\n",
    "### 👩‍💻 Coding experience\n",
    "Being familiar with the following is _**mandatory**_:\n",
    "- [Python](https://www.python.org/)\n",
    "- [Jupyter](https://jupyter.org/)\n",
    "- [PyTorch](https://pytorch.org/docs/stable/index.html)\n",
    "- [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)\n",
    "\n",
    "Knowledge of the following would also be _**nice to have**_:\n",
    "- [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/)\n",
    "- [Hydra](https://hydra.cc/docs/intro/)\n",
    "- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)\n",
    "\n",
    "Finally, having a look at our [README](../README.md) would help you better _**navigate our code structure**_.\n",
    "\n",
    "### 🧑‍🎓 Machine learning experience\n",
    "Whether you intend to **simply understand, make use of, or extend** our method, we **strongly encourage you to read (and cite) our paper [_Efficient 3D Semantic Segmentation with Superpoint Transformer_](https://arxiv.org/abs/2306.08045)** (ICCV 2023).\n",
    "\n",
    "If you are not very familiar with 3D deep learning and self-attention, the following papers provide more context for this work:\n",
    "- [Transformer](https://arxiv.org/abs/1706.03762) (NeurIPS 2017)\n",
    "- [PointNet](https://arxiv.org/abs/1612.00593) (CVPR 2017)\n",
    "- [Superpoint Graph](https://arxiv.org/abs/1711.09869) (CVPR 2018)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "969c39a3-f0a6-449d-bdb2-f71a9f9cb86c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 1. Introduction"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3dd7a80a-9ef6-4d27-9005-7114014f7461",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 👉 [Introductory slides](../media/superpoint_transformer_tutorial.pdf)\n",
    "### 👉 [Tutorial video](https://www.youtube.com/watch?v=2qKhpQs9gJw)\n",
    "\n",
    "This tutorial will demonstrate how to use Superpoint Transformer (SPT) on your own point cloud data. \n",
    "\n",
    "In our running example, we will use a large point cloud from the [Vancouver LiDAR 2022](https://opendata.vancouver.ca/explore/dataset/lidar-2022/map/?location=12,49.25683,-123.14421) dataset and run inference on it with SPT pretrained on [DALES](https://udayton.edu/engineering/research/centers/vision_lab/research/was_data_analysis_and_processing/dale.php), a similar dataset for which we officially provide pretrained weights.\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img width=\"33%\" src=\"../media/dales/sem_gt_demo.png\">\n",
    "</p>\n",
    "\n",
    "Although they both cover similar urban areas, the DALES and Vancouver datasets are far from identical: their semantic segmentation classes differ, their sensor noise and resolution may differ, and Vancouver provides pointwise RGB colors and LiDAR intensity, while DALES only has LiDAR intensity. For these reasons, we will likely want to parametrize and train SPT on Vancouver data rather than simply use a DALES-pretrained model. This tutorial gives guidelines on how to proceed.\n",
    "\n",
    "In particular, we will cover the following:\n",
    "\n",
    "&emsp;✅ reading and visualizing raw point clouds using `Data` objects <br/>\n",
    "&emsp;✅ running inference on custom data using a pretrained SPT <br/>\n",
    "&emsp;✅ parametrizing the preprocessing of the hierarchical superpoint partition on custom data <br/>\n",
    "&emsp;✅ training SPT on a custom dataset <br/>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "227e82fb-c57d-4b80-9911-f66c004c287c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 2. Reading and visualizing raw point clouds using `Data` objects"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0a96951-efec-4fc2-be89-d2a864bad441",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 2.1. Preparing a `Data` reader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5358289-f087-46c8-a381-29ff709c6bc3",
   "metadata": {},
   "source": [
    "Before anything, you will need to define a reader function that parses your raw point cloud files (eg LAS, PLY, txt, ...) and returns a `Data` object holding your points and associated attributes.\n",
    "\n",
    "Our `Data` object is a simple class based on [PyG's `Data` object](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) for holding point clouds (or more generally graphs) with convenient utilities.\n",
    "\n",
    "Below we provide a ready-to-use example of such a reader for parsing LAS files from the [Vancouver LiDAR 2022](https://opendata.vancouver.ca/explore/dataset/lidar-2022/map/?location=12,49.25683,-123.14421) dataset.\n",
    "\n",
    "> **Tip 💡**: You can find inspiration from other point cloud readers implemented for our supported datasets in `src.datasets`. In particular, for PLY format, you may want to have a look at the source code for DALES and KITTI-360."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "540aef5c-3f9e-41ae-8c8a-c557e307c36a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the project's files to the python path\n",
    "# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script\n",
    "file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook\n",
    "sys.path.append(file_path)\n",
    "\n",
    "import laspy\n",
    "import numpy as np\n",
    "import torch\n",
    "from src.data import Data\n",
    "from src.utils.color import to_float_rgb\n",
    "\n",
    "\n",
    "def read_vancouver_tile(\n",
    "        filepath, \n",
    "        xyz=True, \n",
    "        rgb=True, \n",
    "        intensity=True, \n",
    "        semantic=True, \n",
    "        instance=False,\n",
    "        remap=True, \n",
    "        max_intensity=600):\n",
    "    \"\"\"Read a Vancouver tile saved as LAS.\n",
    "\n",
    "    :param filepath: str\n",
    "        Absolute path to the LAS file\n",
    "    :param xyz: bool\n",
    "        Whether XYZ coordinates should be saved in the output Data.pos\n",
    "    :param rgb: bool\n",
    "        Whether RGB colors should be saved in the output Data.rgb\n",
    "    :param intensity: bool\n",
    "        Whether intensity should be saved in the output Data.intensity\n",
    "    :param semantic: bool\n",
    "        Whether semantic labels should be saved in the output Data.y\n",
    "    :param instance: bool\n",
    "        Whether instance labels should be saved in the output Data.obj\n",
    "    :param remap: bool\n",
    "        Whether semantic labels should be mapped from their Vancouver ID\n",
    "        to their train ID\n",
    "    :param max_intensity: float\n",
    "        Maximum value used to clip intensity signal before normalizing \n",
    "        to [0, 1]\n",
    "    \"\"\"\n",
    "    # Create an empty Data object\n",
    "    data = Data()\n",
    "    \n",
    "    las = laspy.read(filepath)\n",
    "\n",
    "    # Populate data with point coordinates\n",
    "    if xyz:\n",
    "        # laspy's lowercase x, y, z apply the scale and offset from the\n",
    "        # LAS header, yielding float64 coordinates\n",
    "        pos = torch.from_numpy(\n",
    "            np.stack([las.x, las.y, las.z], axis=-1).astype(np.float64))\n",
    "        # Subtract the first point in float64 before casting to float32,\n",
    "        # to avoid losing precision on large georeferenced coordinates\n",
    "        pos_offset = pos[0]\n",
    "        data.pos = (pos - pos_offset).float()\n",
    "        data.pos_offset = pos_offset\n",
    "\n",
    "    # Populate data with point RGB colors\n",
    "    if rgb:\n",
    "        # RGB stored in uint16 lives in [0, 65535]\n",
    "        data.rgb = to_float_rgb(torch.stack([\n",
    "            torch.FloatTensor(las[axis].astype('float32') / 65535)\n",
    "            for axis in [\"red\", \"green\", \"blue\"]], dim=-1))\n",
    "\n",
    "    # Populate data with point LiDAR intensity\n",
    "    if intensity:\n",
    "        # Heuristic to bring the intensity distribution in [0, 1]\n",
    "        data.intensity = torch.FloatTensor(\n",
    "            las['intensity'].astype('float32')\n",
    "        ).clip(min=0, max=max_intensity) / max_intensity\n",
    "\n",
    "    # Populate data with point semantic segmentation labels\n",
    "    if semantic:\n",
    "        y = torch.LongTensor(las['classification'])\n",
    "        data.y = torch.from_numpy(ID2TRAINID)[y] if remap else y\n",
    "\n",
    "    # Populate data with point panoptic segmentation labels\n",
    "    if instance:\n",
    "        raise NotImplementedError(\"The dataset does not contain instance labels.\")\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9b30b8e-9cd1-42f0-aa48-15d889e6d631",
   "metadata": {},
   "source": [
    "Often, we need to remap the raw labels provided in a dataset to another set of labels to be used for training.\n",
    "In the next cell, we define some global variables for remapping Vancouver class indices, along with the corresponding class names and colors for downstream visualization.\n",
    "\n",
    "> **Tip 💡**: As described in our [datasets documentation](../docs/datasets.md#semantic-label-format), we consider labels in `[0, num_classes - 1]` to be valid classes and use the `num_classes` label for void/ignored/unlabeled points (whatever you call them)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd4b4c2e-8d2b-4bf7-be1f-ea9b48077523",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Number of classes in the dataset (excluding void/unlabeled/ignored)\n",
    "VANCOUVER_NUM_CLASSES = 6\n",
    "\n",
    "# Mapping from original Vancouver class IDs to train IDs\n",
    "ID2TRAINID = np.asarray([\n",
    "    VANCOUVER_NUM_CLASSES,  # 0 Not used         ->  6 Ignored\n",
    "    5,                      # 1 Other            ->  5 Other\n",
    "    0,                      # 2 Ground           ->  0 Ground\n",
    "    3,                      # 3 Low vegetation   ->  3 Low vegetation\n",
    "    VANCOUVER_NUM_CLASSES,  # 4 Unknown / Noise  ->  6 Ignored\n",
    "    2,                      # 5 High vegetation  ->  2 High vegetation\n",
    "    4,                      # 6 Building         ->  4 Buildings\n",
    "    VANCOUVER_NUM_CLASSES,  # 7 Unknown / Noise  ->  6 Ignored\n",
    "    VANCOUVER_NUM_CLASSES,  # 8 Unknown / Noise  ->  6 Ignored\n",
    "    1])                     # 9 Water            ->  1 Water\n",
    "\n",
    "# Class names (including void/unlabeled/ignored last)\n",
    "VANCOUVER_CLASS_NAMES = [\n",
    "    'Ground',\n",
    "    'Water',\n",
    "    'High vegetation',\n",
    "    'Low vegetation',\n",
    "    'Buildings',\n",
    "    'Other',\n",
    "    'Ignored']\n",
    "\n",
    "# Class color palette (including void/unlabeled/ignored last)\n",
    "VANCOUVER_CLASS_COLORS = np.asarray([\n",
    "    [243, 214, 171],\n",
    "    [169, 222, 249],\n",
    "    [ 70, 115,  66],\n",
    "    [204, 213, 174],\n",
    "    [214,  66,  54],\n",
    "    [186, 160, 164],\n",
    "    [  0,   0,   0]])"
   ]
  },
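  {
   "cell_type": "markdown",
   "id": "5e0c7a41-2b9d-4c3e-8f16-7a2d9b41c0e3",
   "metadata": {},
   "source": [
    "To make this label convention concrete, the sketch below remaps raw labels with a lookup table through NumPy fancy indexing, then masks out points carrying the void label `num_classes`. The toy lookup table and labels are made up for illustration and are not the Vancouver mapping.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy setting (hypothetical): 3 valid train classes, label 3 means void\n",
    "NUM_CLASSES = 3\n",
    "\n",
    "# Lookup table indexed by raw dataset ID, returning the train ID\n",
    "raw_to_train = np.asarray([NUM_CLASSES, 0, 1, NUM_CLASSES, 2])\n",
    "\n",
    "# Raw labels, as they could be read from a file\n",
    "raw_labels = np.asarray([1, 2, 4, 0, 3, 2])\n",
    "\n",
    "# Fancy indexing remaps all labels in one vectorized operation\n",
    "train_labels = raw_to_train[raw_labels]  # [0, 1, 2, 3, 3, 1]\n",
    "\n",
    "# Void points (train label == NUM_CLASSES) are excluded from statistics\n",
    "valid = train_labels < NUM_CLASSES\n",
    "counts = np.bincount(train_labels[valid], minlength=NUM_CLASSES)  # [1, 2, 1]\n",
    "```\n",
    "\n",
    "This is the same pattern as `torch.from_numpy(ID2TRAINID)[y]` in `read_vancouver_tile()`."
   ]
  },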
  {
   "cell_type": "markdown",
   "id": "bd7e1cba-0171-4b7e-9f94-c8d5eb68a1c9",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 2.2. `Data` visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0fcdbba-77f6-46dd-a7d3-ff722ffa9886",
   "metadata": {},
   "source": [
    "We can now download tiles from [Vancouver LiDAR 2022](https://opendata.vancouver.ca/explore/dataset/lidar-2022/map/?location=12,49.25672,-123.14434) and read their content into a `Data` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b330d9-5390-49ab-a50e-7e2648e357ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "filepath = '/path/to/your/vancouver.las'\n",
    "data = read_vancouver_tile(filepath)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2408734-1519-4e9e-9c2b-86c537e45ecc",
   "metadata": {},
   "source": [
    "We have created a `Data` object containing our point cloud and its associated attributes.\n",
    "Let's have a closer look at it!\n",
    "\n",
    "The basic `Data.__repr__()` will show the attributes (ie keys) in Data and their respective shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a211df2-5ae2-44a2-8457-33e1b5b2bff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c29fa630-639a-4ae2-af2e-b6abaff9dd58",
   "metadata": {},
   "source": [
    "You can check the number of points (ie nodes) in a `Data` object with `data.num_points` (or `data.num_nodes`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a948af-4e95-4636-b3a1-6264fba13dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.num_points"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffebf559-d70b-437b-bbc8-b18b11bc5152",
   "metadata": {},
   "source": [
    "You can check the list of attributes stored in a `Data` object with `data.keys`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5c278cd-92a7-47ff-bf57-d3eceb15476b",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.keys"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08d6f14c-9862-42ea-adc2-1ba529618b51",
   "metadata": {},
   "source": [
    "We provide a [Plotly](https://plotly.com/python)-based tool for visualizing `Data` objects. To use it, simply call `data.show()`. This function offers many options for customizing your plot. We will see later on that it can also be used for visualizing hierarchical superpoint partitions held in `NAG` objects.\n",
    "\n",
    "First, let's visualize the whole point cloud contained in `Data` (this may take a couple of seconds if your cloud has $\\sim10^5$ points or more).\n",
    "We can pass our `class_names` and `class_colors` to `show()` to customize how semantic segmentation labels are displayed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd37a2ec-4149-4af9-ae3f-938bd5b9729c",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show(class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41cc2c77-9121-4e05-8d14-34b2b93a17d3",
   "metadata": {},
   "source": [
    "By default, the point cloud is subsampled to at most `max_points=50000` points to reduce visualization time.\n",
    "To get a clearer, high-resolution view, you can increase `max_points` or visualize smaller scenes.\n",
    "For instance, you can display only a spherical crop of the point cloud by specifying a `center` and a `radius`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eca61c3c-5900-4807-896a-1d018f2f40b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show(center=[425, 282, 15], radius=30, keys=['intensity'], class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS)"
   ]
  },
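  {
   "cell_type": "markdown",
   "id": "b3d81f0a-6c27-4e59-9a4b-2f7e0c1d8a52",
   "metadata": {},
   "source": [
    "For intuition, a spherical crop keeps the points lying within `radius` of `center`, and the `max_points` cap can be pictured as random subsampling. The NumPy sketch below illustrates both ideas on random points; the helper names are hypothetical and this is not the actual implementation of `show()`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Random cloud of 100000 points in a 100m x 100m x 20m box\n",
    "pos = rng.uniform([0, 0, 0], [100, 100, 20], size=(100000, 3))\n",
    "\n",
    "def spherical_crop_mask(pos, center, radius):\n",
    "    # True for points within `radius` of `center`\n",
    "    center = np.asarray(center, dtype=pos.dtype)\n",
    "    return ((pos - center) ** 2).sum(axis=1) <= radius ** 2\n",
    "\n",
    "def random_subsample(idx, max_points, rng):\n",
    "    # Keep at most `max_points` indices, drawn without replacement\n",
    "    if len(idx) <= max_points:\n",
    "        return idx\n",
    "    return rng.choice(idx, size=max_points, replace=False)\n",
    "\n",
    "mask = spherical_crop_mask(pos, center=[50, 50, 10], radius=30)\n",
    "idx = random_subsample(np.flatnonzero(mask), max_points=5000, rng=rng)\n",
    "cropped = pos[idx]\n",
    "```\n",
    "\n",
    "Cropping before subsampling keeps more points in the area of interest, which is why small crops give sharper views."
   ]
  },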
  {
   "cell_type": "markdown",
   "id": "39655658-791f-411e-8c41-b6d9d92a90ba",
   "metadata": {},
   "source": [
    "> **Tips 💡**\n",
    "> - More info on our `Data` structure? 👉 see [`docs/data_structures.md`](../docs/data_structures.md), our source code in `src.data.data`, and the [PyG Data documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) it builds upon\n",
    "> - More info on our `show()` visualization tool? 👉 see [`docs/visualization.md`](../docs/visualization.md) and our source code in `src.visualization`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e40ca7d-baab-458a-8b32-7ea17459bc63",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 3. Tiling very large point clouds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a1c9e99-71c5-4347-9676-8ff2946137e6",
   "metadata": {},
   "source": [
    "Sometimes, aerial or terrestrial LiDAR acquisition campaigns produce point cloud files covering extremely large areas 🐘. \n",
    "For instance, the DALES and Vancouver datasets provide tiles at a density of 50 pts/m², spanning 0.25 km² and 1 km², respectively.\n",
    "\n",
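    "As a rough order of magnitude, assuming a uniform density of 50 pts/m², a single 1 km² Vancouver tile holds\n",
    "\n",
    "$$1\\,\\text{km}^2 \\times 50\\,\\text{pts/m}^2 = 10^6\\,\\text{m}^2 \\times 50\\,\\text{pts/m}^2 = 5 \\times 10^7\\ \\text{points},$$\n",
    "\n",
    "while a 0.25 km² DALES tile holds $0.25 \\times 5 \\times 10^7 = 1.25 \\times 10^7$ points.\n",
    "\n",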
    "While Superpoint Transformer is quite scalable, CPU and GPU memory 💾 still limits how large a scene can be processed at once.\n",
    "Since we usually do not _need_ to jointly process all the points in a 500m radius for understanding the scene semantics, it is safe to **tile these into smaller chunks of manageable size**.\n",
    "\n",
    "We propose two tiling strategies in this project: \n",
    "- tiling along the XY coordinate system axes with `SampleXYTiling` 👉 when your clouds already have simple, convex, axis-aligned horizontal layouts like DALES or Vancouver\n",
    "- recursively tiling along the principal XY components with `SampleRecursiveMainXYAxisTiling` 👉 when your clouds have complex horizontal layouts like KITTI-360\n",
    "\n",
    "Let's visualize the impact of the tilings on our current `Data` object (we will run the below example on subsampled data for the sake of faster visualization).\n",
    "\n",
    "> **Tip 💡**: You can skip this section if your `Data` is not that large (eg $\\sim 10^6$ points or fewer with a 24G-32G GPU 🦋). You can still adjust the tiling later on to suit your point cloud size and hardware capabilities if you run into memory issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a747670-8327-4ae3-a773-7ed1ca9e61e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.transforms import SampleXYTiling, GridSampling3D\n",
    "from src.data import Batch\n",
    "\n",
    "# Tile the cloud into `xy_tiling` XY-oriented chunks of equal horizontal \n",
    "# span\n",
    "xy_tiling = (4, 2)\n",
    "\n",
    "# Voxelize the point cloud only for the sake of faster computation and \n",
    "# visualization here\n",
    "data_5m = GridSampling3D(5)(data)\n",
    "\n",
    "# Compute each chunk \n",
    "chunks = []\n",
    "for x in range(xy_tiling[0]):\n",
    "    for y in range(xy_tiling[1]):        \n",
    "        # Extract the chunk at (x, y) in the tiling grid\n",
    "        chunk = SampleXYTiling(x=x, y=y, tiling=xy_tiling)(data_5m)\n",
    "\n",
    "        # Add a 'tile' attribute to the points for visualization\n",
    "        chunk.tile = torch.full((chunk.num_points,), x * xy_tiling[1] + y)\n",
    "        \n",
    "        # Store the chunk for later aggregation\n",
    "        chunks.append(chunk)\n",
    "\n",
    "# Aggregate all chunk `Data` objects into one big `Data` object\n",
    "data_tiled = Batch.from_data_list(chunks)\n",
    "\n",
    "# Show the resulting `Data` with the 'tile' attribute\n",
    "data_tiled.show(keys='tile')"
   ]
  },
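  {
   "cell_type": "markdown",
   "id": "c7a2e9d4-1f3b-4a86-b05c-9e4d2a7f1b38",
   "metadata": {},
   "source": [
    "To clarify what tiling into chunks of equal horizontal span means, here is a minimal NumPy sketch that assigns each point to a cell of an `(nx, ny)` grid covering the XY bounding box of the cloud. It is our own illustration under that assumption; `SampleXYTiling` may handle borders and indexing differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def xy_tile_index(pos, tiling):\n",
    "    # Flat index of the (nx, ny) grid cell containing each point, the grid\n",
    "    # covering the XY bounding box of the cloud\n",
    "    nx, ny = tiling\n",
    "    xy = pos[:, :2]\n",
    "    xy_min = xy.min(axis=0)\n",
    "    span = np.maximum(xy.max(axis=0) - xy_min, 1e-12)\n",
    "    # Normalize XY to [0, 1], then scale to cell coordinates\n",
    "    rel = (xy - xy_min) / span\n",
    "    # Points on the max border are clamped into the last cell\n",
    "    ix = np.minimum((rel[:, 0] * nx).astype(int), nx - 1)\n",
    "    iy = np.minimum((rel[:, 1] * ny).astype(int), ny - 1)\n",
    "    # Same flat numbering as the 'tile' attribute above: x * ny + y\n",
    "    return ix * ny + iy\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "pos = rng.uniform(0, 100, size=(10000, 3))\n",
    "tile = xy_tile_index(pos, tiling=(4, 2))\n",
    "\n",
    "# Number of points per chunk, and the points of the chunk at (x=1, y=0)\n",
    "points_per_chunk = np.bincount(tile, minlength=8)\n",
    "chunk = pos[tile == 1 * 2 + 0]\n",
    "```\n",
    "\n",
    "Checking the number of points per chunk is a quick way to tell whether a given tiling fits your memory budget."
   ]
  },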
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bde0716a-b605-4353-b014-174fbf215fe1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.transforms import SampleRecursiveMainXYAxisTiling, GridSampling3D\n",
    "from src.data import Batch\n",
    "\n",
    "# Recursively tile the cloud into `2**pc_tiling` chunks with respect to \n",
    "# principal components of the XY coordinates\n",
    "pc_tiling = 3\n",
    "\n",
    "# Voxelize the point cloud only for the sake of faster computation and \n",
    "# visualization here\n",
    "data_5m = GridSampling3D(5)(data)\n",
    "\n",
    "# Compute each chunk \n",
    "chunks = []\n",
    "for x in range(2**pc_tiling):\n",
    "    # Extract the chunk at x in the recursive tiling\n",
    "    chunk = SampleRecursiveMainXYAxisTiling(x=x, steps=pc_tiling)(data_5m)\n",
    "\n",
    "    # Add a 'tile' attribute to the points for visualization\n",
    "    chunk.tile = torch.full((chunk.num_points,), x)\n",
    "    \n",
    "    # Store the chunk for later aggregation\n",
    "    chunks.append(chunk)\n",
    "\n",
    "# Aggregate all chunk `Data` objects into one big `Data` object\n",
    "data_tiled = Batch.from_data_list(chunks)\n",
    "\n",
    "# Show the resulting `Data` with the 'tile' attribute\n",
    "data_tiled.show(keys='tile')"
   ]
  },
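  {
   "cell_type": "markdown",
   "id": "d19f4b6e-8a2c-4f07-a3e5-6b1c9d0e7f24",
   "metadata": {},
   "source": [
    "One way to picture recursive tiling along the principal XY components: compute the main axis of the XY coordinates with a PCA, split the points at the median along that axis, and repeat on each half for `steps` rounds, yielding `2**steps` chunks. The NumPy sketch below implements this idea under our own assumptions (median split, PCA via `np.linalg.eigh`); the actual `SampleRecursiveMainXYAxisTiling` may choose split positions or chunk ordering differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def main_axis_split(xy):\n",
    "    # Project the centered XY coordinates on their first principal component\n",
    "    centered = xy - xy.mean(axis=0)\n",
    "    cov = centered.T @ centered / max(len(xy) - 1, 1)\n",
    "    _, eigvec = np.linalg.eigh(cov)\n",
    "    # np.linalg.eigh sorts eigenvalues in ascending order: the main axis is last\n",
    "    proj = centered @ eigvec[:, -1]\n",
    "    # Split at the median so that both halves hold about half of the points\n",
    "    return proj <= np.median(proj)\n",
    "\n",
    "def recursive_tiling(pos, steps):\n",
    "    # Chunk id in [0, 2**steps) for each point\n",
    "    tile = np.zeros(len(pos), dtype=int)\n",
    "    for _ in range(steps):\n",
    "        new_tile = np.empty_like(tile)\n",
    "        for t in np.unique(tile):\n",
    "            idx = np.flatnonzero(tile == t)\n",
    "            left = main_axis_split(pos[idx, :2])\n",
    "            new_tile[idx] = 2 * t + np.where(left, 0, 1)\n",
    "        tile = new_tile\n",
    "    return tile\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Elongated, rotated cloud, loosely mimicking a street-like layout\n",
    "s = rng.uniform(-1, 1, size=5000)\n",
    "pos = np.stack([200 * s, 100 * s + rng.normal(0, 5, size=5000),\n",
    "                rng.uniform(0, 10, size=5000)], axis=1)\n",
    "tile = recursive_tiling(pos, steps=3)\n",
    "```\n",
    "\n",
    "Unlike a fixed XY grid, the split directions follow the cloud's own layout, so chunks stay balanced even for elongated, rotated scenes."
   ]
  },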
  {
   "cell_type": "markdown",
   "id": "745d55d5-5a36-4bb3-8dfb-096166a30a0f",
   "metadata": {},
   "source": [
    "Since the Vancouver point cloud is XY-axis aligned and has a simple square XY layout, we choose to use `SampleXYTiling` here.\n",
    "**For the rest of this tutorial, we will work on one of the chunks of the original point cloud.**\n",
    "Feel free to adjust the tiling method and the chosen tile to your dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5331d667-0a1b-4108-a601-7b5aa46f47e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.transforms import SampleXYTiling\n",
    "\n",
    "# Extract the central chunk (x=1, y=1) of a 3x3 XY tiling grid\n",
    "data = SampleXYTiling(x=1, y=1, tiling=3)(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed04e913-8bd0-4a56-99ff-c27f6fd5b2be",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 4. Using a pretrained model for inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa568587-4408-481d-b533-d8990d791402",
   "metadata": {},
   "source": [
    "We provide pretrained weights and preprocessing parametrization for several datasets (see [README](../README.md) and [datasets documentation](../docs/datasets.md)). Since the Vancouver dataset is fairly similar to DALES, we would like to check how a DALES-pretrained SPT would fare on our present `Data` object.\n",
    "\n",
    "As mentioned in the [introductory slides](../media/superpoint_transformer_tutorial.pdf), running inference with a pretrained SPT requires more than just the model weights. Indeed, we also need to apply to our `Data` the same `pre_transform` and `on_device_transform` as those used for training the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86a5b9e6-aad2-47c5-a6cb-c0d776fb8133",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 4.1. Instantiating transforms from `configs/`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37863cae-bd91-461c-b5b9-1cd59c56aed7",
   "metadata": {},
   "source": [
    "First, we need to recover the transforms used in the DALES experiment, as specified in `configs/experiment`, using [Hydra](https://hydra.cc/docs/intro/).\n",
    "In the next cell, we show how to use the `init_config()` utility to get the **exact configuration used for training the released DALES model**.\n",
    "\n",
    "> **Tips 💡**\n",
    "> - More info on how `configs/` & [Hydra](https://hydra.cc/docs/intro/) work? 👉 see the [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) repository\n",
    "> - More info on a specific experiment's settings? 👉 explore our configuration files in `configs/`, which are extensively commented 😉"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b45daf-4ced-4eed-ab3f-fbc1b7a2f9b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import init_config\n",
    "\n",
    "cfg = init_config(overrides=[f\"experiment=semantic/dales\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b48e129c-8a95-4a27-9133-292a341aeeac",
   "metadata": {},
   "source": [
    "This `cfg` is an [omegaconf](https://omegaconf.readthedocs.io) `DictConfig` object. It contains all the hyperparameters needed to reproduce the pretraining experiment: dataset, model structure, training recipe, etc. We can explore its content like a dictionary (eg `cfg['model']`) or like an object (eg `cfg.model`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c864cd81-c4f3-4ec0-a628-fdbb5e4787c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg.keys()"
   ]
  },
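  {
   "cell_type": "markdown",
   "id": "e4b07c2a-3d9f-4e18-b6a1-8c5f2e9d0a73",
   "metadata": {},
   "source": [
    "If you are new to omegaconf, the toy `DictConfig` below shows dictionary-style and attribute-style access, as well as `OmegaConf.to_yaml()` for pretty-printing a sub-config. Its keys are made up and do not reflect the actual DALES configuration.\n",
    "\n",
    "```python\n",
    "from omegaconf import DictConfig, OmegaConf\n",
    "\n",
    "# Toy config with made-up keys, for illustration only\n",
    "toy_cfg = OmegaConf.create({\n",
    "    'datamodule': {'dataroot': 'data', 'point_load_keys': ['pos', 'y']},\n",
    "    'model': {'num_classes': 6},\n",
    "})\n",
    "\n",
    "# Dictionary-style and attribute-style access return the same value\n",
    "print(toy_cfg['model']['num_classes'] == toy_cfg.model.num_classes)  # True\n",
    "\n",
    "# Nested nodes are DictConfig objects too, with dict-like methods\n",
    "print(isinstance(toy_cfg.datamodule, DictConfig))  # True\n",
    "print(list(toy_cfg.keys()))  # ['datamodule', 'model']\n",
    "\n",
    "# Membership tests work on list nodes, as used with `point_load_keys` later on\n",
    "print('pos' in toy_cfg.datamodule.point_load_keys)  # True\n",
    "\n",
    "# Human-readable YAML dump of a sub-config\n",
    "print(OmegaConf.to_yaml(toy_cfg.datamodule))\n",
    "```\n",
    "\n",
    "On the real `cfg`, `print(OmegaConf.to_yaml(cfg.datamodule))` is a convenient way to inspect the transforms parametrization."
   ]
  },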
  {
   "cell_type": "markdown",
   "id": "c0f56c8b-935d-4cd8-ad1e-9a61dab72dba",
   "metadata": {},
   "source": [
    "The parametrization of the transforms is specified in the datamodule config in `cfg.datamodule`.\n",
    "We can instantiate the transforms from an [omegaconf](https://omegaconf.readthedocs.io) `DictConfig` object without instantiating the whole dataset by using the `instantiate_datamodule_transforms()` utility."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade93e52-97ac-4e86-9a40-c3c08f9e0c62",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.transforms import instantiate_datamodule_transforms\n",
    "\n",
    "transforms_dict = instantiate_datamodule_transforms(cfg.datamodule)\n",
    "transforms_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "258199eb-dd95-4769-b5c0-b6ea2e38edb2",
   "metadata": {},
   "source": [
    "The transforms are chained operations applied to a `Data` or a `NAG` object. Their order and parametrization play a significant role, and modifying them may have non-negligible downstream effects. **They must be thought of as part of the model itself**."
   ]
  },
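  {
   "cell_type": "markdown",
   "id": "f2c9a6d1-7e4b-4b30-8d2f-1a6e5c3b9f07",
   "metadata": {},
   "source": [
    "To illustrate why the order of transforms matters, here is a toy pipeline built from a list of (name, params) entries, loosely mirroring how a config lists transforms. The `Compose` class, the registry, the transform names and the spec format are hypothetical and only meant for illustration. The two toy transforms do not commute, so swapping them changes the output.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "class Compose:\n",
    "    # Apply a list of callables sequentially, in the given order\n",
    "    def __init__(self, transforms):\n",
    "        self.transforms = list(transforms)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        for transform in self.transforms:\n",
    "            x = transform(x)\n",
    "        return x\n",
    "\n",
    "# Hypothetical registry mapping transform names to factories\n",
    "REGISTRY = {\n",
    "    'Center': lambda: (lambda pos: pos - pos.mean(axis=0)),\n",
    "    'Voxelize': lambda size: (lambda pos: np.unique(np.floor(pos / size) * size, axis=0)),\n",
    "}\n",
    "\n",
    "def build(specs):\n",
    "    # Instantiate each (name, params) entry, preserving the list order\n",
    "    return Compose(REGISTRY[name](**params) for name, params in specs)\n",
    "\n",
    "pos = np.array([[0.0, 0.0, 0.0], [0.4, 0.0, 0.0], [1.2, 0.0, 0.0]])\n",
    "a = build([('Center', {}), ('Voxelize', {'size': 1.0})])(pos)\n",
    "b = build([('Voxelize', {'size': 1.0}), ('Center', {})])(pos)\n",
    "# a has x coordinates [-1, 0], b has x coordinates [-0.5, 0.5]\n",
    "```"
   ]
  },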
  {
   "cell_type": "markdown",
   "id": "38d0dbf9-62ca-45b6-b5e4-40ff827bf25d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 4.2. Applying transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e268c002-8179-4900-9ce0-3a16cb8f89e8",
   "metadata": {},
   "source": [
    "As explained in the [introductory slides](../media/superpoint_transformer_tutorial.pdf), we will be using `pre_transform` and `on_device_test_transform` to reproduce the behavior of the pretrained model at inference time.\n",
    "\n",
    "> **Note 🤓**: In the next cell, we manually apply `NAGRemoveKeys()` transforms after the `pre_transform`. This is because we need to mimic the full behavior of the pretraining `Dataset`: after the `pre_transform` is executed, the preprocessed `NAG` is saved to disk. When later read from disk by the `Dataset`, only the `point_load_keys` attributes of `NAG[0]` and the `segment_load_keys` attributes of `NAG[i], i>0` are loaded. This mechanism ensures we only load what is strictly necessary during training, hence saving I/O time. Since we are running the `pre_transform` manually here, we need to account for this mechanism and discard the preprocessed attributes that the DALES dataset does not read from disk. These can be found in `cfg.datamodule.point_load_keys` and `cfg.datamodule.segment_load_keys`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41463ce-5f93-4a60-a276-a759f638f14b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply pre-transforms\n",
    "nag = transforms_dict['pre_transform'](data)\n",
    "\n",
    "# Simulate the dataset's I/O behavior, where only `point_load_keys`\n",
    "# (level 0) and `segment_load_keys` (levels 1+) are loaded from disk\n",
    "from src.transforms import NAGRemoveKeys\n",
    "nag = NAGRemoveKeys(level=0, keys=[k for k in nag[0].keys if k not in cfg.datamodule.point_load_keys])(nag)\n",
    "nag = NAGRemoveKeys(level='1+', keys=[k for k in nag[1].keys if k not in cfg.datamodule.segment_load_keys])(nag)\n",
    "\n",
    "# Move to device\n",
    "nag = nag.cuda()\n",
    "\n",
    "# Apply on-device transforms\n",
    "nag = transforms_dict['on_device_test_transform'](nag)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7a83c80-844d-4f71-ae85-4b1fcfbdce9d",
   "metadata": {},
   "source": [
    "The output of the transforms is no longer a `Data` object, but a `NAG`. This is the data structure we use to carry around **point clouds** and **hierarchical superpoint partitions**. \n",
    "\n",
    "Essentially, it is a list of `Data` objects, each representing a partition level:\n",
    "- `nag[0]` is $P_0$, the (voxelized) points\n",
    "- `nag[i]` is $P_i$, the $\\text{i}^\\text{th}$ superpoint partition level \n",
    "\n",
    "At each level $i>0$, the `edge_index` and `edge_attr` attributes carry the **superpoint adjacency graph** and corresponding **adjacency features**.\n",
    "\n",
    "> **Tip 💡**: More info on our `NAG` structure? 👉 see [`docs/data_structures.md`](../docs/data_structures.md) and source code in `src.data.nag`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33e6b0ea-eda3-4825-b5bc-8e35fafb8de3",
   "metadata": {},
   "outputs": [],
   "source": [
    "nag"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fb7ce51-8ae5-4982-8442-831be1d9cf84",
   "metadata": {},
   "source": [
    "Let's visualize the impact of the transforms on the data in a small area, for a high-resolution display. Note that we can display the nodes and edges of the superpoint graphs by passing `show(centroids=True, h_edge=True)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f52994e7-9433-4c7d-8c6f-46c922deb957",
   "metadata": {},
   "outputs": [],
   "source": [
    "nag.show(class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS, center=[485, 505, 0], radius=20, keys=nag[0].keys, centroids=True, h_edge=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e221bacf-a499-479d-9468-cfada4e0b3d9",
   "metadata": {},
   "source": [
    "Now that we have preprocessed our data, we can run inference with the pretrained model.\n",
    "\n",
    "> **Tip 💡**: If you want to save your progress to disk, both `Data` and `NAG` have `.save()` and `.load()` methods specifically designed with fast I/O and disk usage in mind 😉."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4847c821-855d-4913-9e4b-a35c200203e7",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 4.3. Instantiating a pretrained model from `configs/` and a `*.ckpt`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "066d2505-dc66-439b-b7f0-02d4267fd2e1",
   "metadata": {},
   "source": [
    "Similar to the transforms, we will use the DALES experiment configuration files to instantiate the **pretrained model**. \n",
    "This time, the part of the [omegaconf](https://omegaconf.readthedocs.io) `DictConfig` object we are interested in is stored under `cfg.model`.\n",
    "\n",
    "As stated in the [README](../README.md), the pretrained weights for our models can be recovered from [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.8042712.svg)](https://doi.org/10.5281/zenodo.8042712)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89222181-655a-4a3b-a9f2-2ab2983f9adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import hydra\n",
    "from src.utils import init_config\n",
    "\n",
    "# Path to the checkpoint file downloaded from https://zenodo.org/records/8042712\n",
    "ckpt_path = \"/path/to/your/superpoint_transformer.ckpt\"\n",
    "\n",
    "# Compose the DALES experiment configuration\n",
    "cfg = init_config(overrides=[\"experiment=semantic/dales\"])\n",
    "\n",
    "# Instantiate the model and load pretrained weights\n",
    "model = hydra.utils.instantiate(cfg.model)\n",
    "model = model._load_from_checkpoint(ckpt_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6241d311-5843-4ad0-a48b-463d33be255c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 4.4. Applying SPT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32ef6135-8b2e-499c-937f-4e3ee8afe5bc",
   "metadata": {},
   "source": [
    "Everything is now ready to run inference!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "143c799f-f7a5-4a2f-8a4d-c378191a0009",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Set the model in inference mode on the same device as the input\n",
    "model = model.eval().to(nag.device)\n",
    "\n",
    "# Inference, returns a task-specific output object carrying predictions\n",
    "with torch.no_grad():\n",
    "    output = model(nag)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4712247-4af5-4a5f-b7cc-b75c4f7372aa",
   "metadata": {},
   "source": [
    "The output of the model is a `SemanticSegmentationOutput` object. It is a simple class dedicated to holding predictions, accessible via `output.semantic_pred()`, and to facilitating basic post-processing operations such as metric computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a48bc881-aa88-435d-b989-8dc4af7cdbe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "output.semantic_pred().shape, nag.num_points"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b97c1af4-a6c9-4f14-a553-8ed46098c417",
   "metadata": {},
   "source": [
    "As stated in the [introductory slides](../media/superpoint_transformer_tutorial.pdf), it is important to remember that, by default, **SPT outputs predictions on the $P_1$ level** (ie `nag[1]`). Since the superpoints $P_1$ are assumed to be semantically pure, simply classifying those is equivalent to classifying each point in the scene. In doing so, we save a lot of computation and memory during training.\n",
    "\n",
    "Yet, at inference time, we often want the predictions at the voxel level $P_0$ (ie `nag[0]`) or even at the full resolution of the raw input cloud.\n",
    "To this end, we simply need to distribute the $P_1$ predictions to the lower partition levels.\n",
    "The `SemanticSegmentationOutput.voxel_semantic_pred()` and `SemanticSegmentationOutput.full_res_semantic_pred()` methods were designed just for that!\n",
    "\n",
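    "Distributing $P_1$ predictions to the lower levels essentially boils down to an indexing operation. The minimal NumPy sketch below illustrates the idea on hypothetical toy arrays. It assumes that each voxel stores the index of its parent superpoint (as `nag[0].super_index` does) and that a hypothetical `point_to_voxel` array maps raw points to voxels; it is not the actual implementation of `voxel_semantic_pred()` or `full_res_semantic_pred()`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy inputs: 3 superpoints in P_1 and 6 voxels in P_0\n",
    "superpoint_pred = np.array([2, 0, 1])       # one predicted label per P_1 superpoint\n",
    "super_index = np.array([0, 0, 1, 2, 2, 1])  # parent superpoint of each P_0 voxel\n",
    "\n",
    "# Each voxel inherits the prediction of its parent superpoint\n",
    "voxel_pred = superpoint_pred[super_index]\n",
    "\n",
    "# Hypothetical mapping from 5 raw full-resolution points to their voxel\n",
    "point_to_voxel = np.array([0, 1, 1, 3, 5])\n",
    "full_res_pred = voxel_pred[point_to_voxel]\n",
    "\n",
    "print(voxel_pred)     # [2 2 0 1 1 0]\n",
    "print(full_res_pred)  # [2 2 2 1 0]\n",
    "```\n",
    "\n",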
    "In the next cell, we will convert $P_1$ predictions into $P_0$ predictions.\n",
    "\n",
    "> **Tip 💡**: For **full-resolution predictions**, see our [`demo.ipynb` notebook](../notebooks/demo.ipynb), and have a look at [`src/utils/output_semantic.py`](../src/utils/output_semantic.py#L140). Remember that if you have applied a tiling to your data, your full-resolution predictions will be given for the tile at hand and not the original point cloud.\n",
    "\n",
    "> **Note 🤓**: Although SPT does make predictions as $P_1$ node classifications, all losses and metrics are properly computed so as to take into account the true labels assigned to full-resolution points. To make these efficient, our pipeline always tracks the **histogram of ground truth labels** for each voxel in $P_0$ and superpoint in $P_i, i>0$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f59a31-3887-4c6a-98b4-5aa54ee97c96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the level-0 (voxel-wise) semantic segmentation predictions \n",
    "# based on the predictions on level-1 superpoints and save those for \n",
    "# visualization in the level-0 Data under the 'semantic_pred' attribute\n",
    "nag[0].semantic_pred = output.voxel_semantic_pred(super_index=nag[0].super_index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "625b77c8-21a6-42d5-ae51-7271c551b07e",
   "metadata": {},
   "source": [
    "Let's visualize the resulting predictions on a small area for high-resolution display.\n",
    "\n",
    "Note that since the model was trained on DALES classes, the predicted labels do not align with those of our Vancouver dataset. \n",
    "For better visualization, we will use the DALES `CLASS_NAMES` and `CLASS_COLORS`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51d31e7d-5f8f-44db-9765-7d426dfaca42",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.datasets.dales import CLASS_NAMES as DALES_CLASS_NAMES\n",
    "from src.datasets.dales import CLASS_COLORS as DALES_CLASS_COLORS\n",
    "\n",
    "nag.show(class_names=DALES_CLASS_NAMES, class_colors=DALES_CLASS_COLORS, center=[485, 505, 0], radius=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50ac5aa5-0c63-47a7-aa40-45c15e0f3d6d",
   "metadata": {},
   "source": [
    "We can see that the DALES-pretrained model is actually doing a pretty good job on the Vancouver dataset!\n",
    "\n",
    "Still, the DALES classes and Vancouver classes are not the same. If we are particularly interested in identifying Vancouver classes such as _low vegetation_ or _water_, we will need to train a dedicated model on the Vancouver data. Besides, we may also want to adjust the preprocessing steps in `pre_transform`: different parameters may produce partitions that better respect the semantic boundaries of Vancouver classes."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fdf2863-2ffc-4ca6-b67e-566c4c548917",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 5. Parametrizing the partition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fd03f39-ce30-4bde-af02-ca429614ee18",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 5.1. Assessing partition quality"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43c6443c-0732-4ff0-a6e9-5304346ade17",
   "metadata": {},
   "source": [
    "There are many ways to parametrize your preprocessing `pre_transform`, and finding a good setting for a new dataset is usually a matter of 'just trying'. \n",
    "Still, there are some **guidelines for what a _good_ partition might be**:\n",
    "- **⚡ efficiency** - it must simplify the scene by having as few superpoints as possible 👉 measured with `NAG.level_ratios`\n",
    "> `NAG.level_ratios` computes the ratio of the number of elements between successive partition levels. \n",
    "- **🎯 accuracy** - it must respect the semantic boundaries of objects 👉 measured with `Data.semantic_segmentation_oracle()`\n",
    "> `Data.semantic_segmentation_oracle()` computes the semantic segmentation metrics of a hypothetical _oracle_ model capable of predicting the majority label for each superpoint. To compute this, we use the fact that labels in `nag[i].y` are stored as histograms, which allows for computing _exact_ full-resolution metrics (even accounting for the voxelization of $P_0$).\n",
    "\n",
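    "Both measurements can be reproduced on toy data. The NumPy sketch below is a simplified re-implementation under our own assumptions, not the code behind `NAG.level_ratios` or `Data.semantic_segmentation_oracle()`. The hypothetical `label_hist[s, c]` array is assumed to hold the number of full-resolution points of class `c` in superpoint `s`, and any special handling of void labels is left out.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def level_ratios(level_sizes):\n",
    "    # Ratio |P_i| / |P_{i+1}| between successive partition levels\n",
    "    return [level_sizes[i] / level_sizes[i + 1] for i in range(len(level_sizes) - 1)]\n",
    "\n",
    "def oracle_miou(label_hist):\n",
    "    # label_hist[s, c]: number of full-resolution points of class c in superpoint s\n",
    "    num_classes = label_hist.shape[1]\n",
    "    # The oracle predicts the majority label of each superpoint\n",
    "    oracle_pred = label_hist.argmax(axis=1)\n",
    "    # Point-weighted confusion matrix: rows are true classes, columns are predictions\n",
    "    confusion = np.zeros((num_classes, num_classes))\n",
    "    for s, pred in enumerate(oracle_pred):\n",
    "        confusion[:, pred] += label_hist[s]\n",
    "    tp = np.diag(confusion)\n",
    "    union = confusion.sum(axis=0) + confusion.sum(axis=1) - tp\n",
    "    # Average the IoU over classes with a non-empty union\n",
    "    valid = union > 0\n",
    "    return (tp[valid] / union[valid]).mean()\n",
    "\n",
    "print(level_ratios([1000, 25, 5]))  # [40.0, 5.0]\n",
    "label_hist = np.array([[8, 2], [0, 5], [1, 4]])\n",
    "print(oracle_miou(label_hist))\n",
    "```\n",
    "\n",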
    "#### 🔮 **Rules of thumb** - _Don't take these for granted on every dataset, but aiming for them can get you started._\n",
    "\n",
    "We usually aim for:\n",
    "- $\\frac{|P_0|}{|P_1|} \\in [30, 50]$\n",
    "- $\\frac{|P_i|}{|P_{i+1}|} \\in [3, 10],\\quad i > 0$\n",
    "- $\\text{oracle mIoU on } P_1 > 0.95$\n",
    "\n",
    "Beyond these quantified measurements, it is also important that you **visualize your partitions** and, given your own domain expertise, check whether they make sense for the task you are interested in.\n",
    "\n",
    "Let's check the efficiency and accuracy of the current partition on the `NAG` at hand.\n",
    "\n",
    "> **Tip 💡**: In practice you would want to compute and accumulate these values on your entire dataset, or at least on several representative tiles. Here we only compute these on a single tile for simplicity. Scaling the present single-tile study to multiple tiles will be up to you, but we would recommend you implement your own `Dataset` for that (see next section) 😉"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "797a203e-924c-4c72-a156-0aef35de5371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio of sizes of successive partition levels\n",
    "nag.level_ratios"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46e9c8e0-5b62-4a51-b27d-5cee21cf3276",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Oracle semantic segmentation metrics on P_1\n",
    "nag[1].semantic_segmentation_oracle(VANCOUVER_NUM_CLASSES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0b13cda-14f1-484d-87d5-35b91f867ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Oracle semantic segmentation metrics on P_0\n",
    "nag[0].semantic_segmentation_oracle(VANCOUVER_NUM_CLASSES)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07ffe997-2e2f-4d62-8808-1685fb9a3a5d",
   "metadata": {},
   "source": [
    "As we can see, the partition is reasonably good, but we may want to improve the $P_1$ oracle mIoU a little."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27563ead-5ab0-4511-a23b-5f3807386775",
   "metadata": {},
   "source": [
    "### 5.2. Adjusting the partition parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "838fc7e8-a568-4973-8635-c9bcca51f89c",
   "metadata": {},
   "source": [
    "As mentioned in the [introductory slides](../media/superpoint_transformer_tutorial.pdf), the `pre_transform` typically includes the following steps:\n",
    "\n",
    "| Operation | `Transform` |\n",
    "| :------ | :------ |\n",
    "| Voxelization | [`GridSampling3D`](../src/transforms/sampling.py#L59C7-L59C21) |\n",
    "| Neighbor search | [`KNN`](../src/transforms/neighbors.py#L9)  |\n",
    "| Elevation estimation | [`GroundElevation`](../src/transforms/point.py#L223) |\n",
    "| Pointwise local geometric features | [`PointFeatures`](../src/transforms/point.py#L17) |\n",
    "| Adjacency graph | [`AdjacencyGraph`](../src/transforms/graph.py#L24) |\n",
    "| Hierarchical partition | [`CutPursuitPartition`](../src/transforms/partition.py#L23) |\n",
    "| Superpoint-wise handcrafted features | [`SegmentFeatures`](../src/transforms/graph.py#L75) |\n",
    "| Superpoint adjacency graph and features | [`RadiusHorizontalGraph`](../src/transforms/graph.py#L548) |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f1a2dd-d683-407b-ba17-a40161a87799",
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms_dict['pre_transform']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0f085ad-7bb8-4724-aba7-a2310dcb7035",
   "metadata": {},
   "source": [
    "Let's have a quick look at how some of these operations affect the `Data` object. To this end, we will re-read the raw data to start from scratch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6b150a-d394-4b05-a7e6-2d26a6338e83",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.transforms import *\n",
    "\n",
    "data = read_vancouver_tile(filepath)\n",
    "data = SampleXYTiling(x=1, y=1, tiling=3)(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a86610-7c7f-4786-b299-1b12a841de1d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### 5.2.1. Voxelization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87040791-01b1-4218-89ac-280b5ab488af",
   "metadata": {},
   "source": [
    "`GridSampling3D(size=...)` voxelizes a point cloud into voxels of side `size`. This first step is not specific to Superpoint Transformer: it is shared by most point cloud preprocessing pipelines, even those that are not explicitly voxel-based. This **mitigates sampling density disparities** and reduces the size of the point cloud, hence **reducing downstream compute and memory costs**.\n",
    "\n",
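    "To make the aggregation concrete, here is a minimal NumPy sketch of grid sampling that averages positions and builds a label histogram in each voxel. The `grid_sample` helper is hypothetical and only covers these two aggregation modes; `GridSampling3D` itself supports more and is implemented differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def grid_sample(pos, labels, size, num_classes):\n",
    "    # Integer coordinates of the voxel containing each point\n",
    "    voxel_coords = np.floor(pos / size).astype(np.int64)\n",
    "    # Unique voxels and, for each point, the index of its voxel\n",
    "    _, point_to_voxel = np.unique(voxel_coords, axis=0, return_inverse=True)\n",
    "    point_to_voxel = point_to_voxel.reshape(-1)\n",
    "    num_voxels = point_to_voxel.max() + 1\n",
    "    # Mean aggregation of the positions falling in each voxel\n",
    "    counts = np.bincount(point_to_voxel, minlength=num_voxels)\n",
    "    voxel_pos = np.zeros((num_voxels, pos.shape[1]))\n",
    "    np.add.at(voxel_pos, point_to_voxel, pos)\n",
    "    voxel_pos /= counts[:, None]\n",
    "    # Histogram aggregation of the labels falling in each voxel\n",
    "    voxel_hist = np.zeros((num_voxels, num_classes), dtype=np.int64)\n",
    "    np.add.at(voxel_hist, (point_to_voxel, labels), 1)\n",
    "    return voxel_pos, voxel_hist, point_to_voxel\n",
    "\n",
    "pos = np.array([[0.05, 0.05, 0.0], [0.08, 0.02, 0.0], [0.15, 0.05, 0.0]])\n",
    "labels = np.array([1, 2, 1])\n",
    "voxel_pos, voxel_hist, point_to_voxel = grid_sample(pos, labels, size=0.1, num_classes=3)\n",
    "print(point_to_voxel.tolist())  # [0, 0, 1]\n",
    "print(voxel_hist.tolist())      # [[0, 1, 1], [0, 1, 0]]\n",
    "```\n",
    "\n",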
    "> **Tip 💡**: Keep the [Nyquist-Shannon theorem](https://en.wikipedia.org/wiki/Nyquist%E2%80%93Shannon_sampling_theorem) in mind when deciding on a voxel size. You typically want your voxel size to be **at most half the size of the smallest structure you want to characterize**. This puts an upper bound on the voxel sizes you should consider. Still, using smaller voxels (ie higher point resolution) usually comes with higher model performance at the expense of compute and memory efficiency.\n",
    "\n",
    "> **Note 🤓**: `GridSampling3D` offers advanced mechanisms for aggregating point attributes inside each voxel, based on their nature:\n",
    "> - mean aggregation (eg for float values like the position or colors)\n",
    "> - last encountered value (eg for identical values like the batch index)\n",
    "> - histogram (eg for semantic segmentation labels)\n",
    "> - voting for dominant value (eg for semantic segmentation labels, superpoint indices)\n",
    "> - merging into a `Cluster` object (eg for full-resolution point indices)\n",
    "> - unit-normalized vector combination (eg for normals)\n",
    "> \n",
    "> See the [source code](../src/transforms/sampling.py#L59C7-L59C21) for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5dee99d-f7ff-4814-997d-b2aa5a626073",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_voxelized = GridSampling3D(size=1, hist_key='y', hist_size=VANCOUVER_NUM_CLASSES + 1)(data)\n",
    "data_voxelized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ced4743-fe84-45e4-8373-3469ffa7fbbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.num_nodes / data_voxelized.num_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4868bda-fd16-4291-84dd-42d146cac2c4",
   "metadata": {},
   "source": [
    "In our case, the 10 cm voxel size already selected for DALES is well suited to Vancouver, so we will keep it as is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7851c44-c8f3-42eb-93b5-5b1682f1f09a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = GridSampling3D(size=0.1, hist_key='y', hist_size=VANCOUVER_NUM_CLASSES + 1)(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b895472d-8303-44c6-951d-b36a521f518b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### 5.2.2. Neighbor search"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9ef0789-4e87-47d8-9441-9b938db2f1d8",
   "metadata": {},
   "source": [
    "`KNN(k=..., r_max=...)` searches for the `k` nearest neighbors of each point, within a maximum radius of `r_max`. Unlike a basic K-NN search, the radius constraint prevents spurious neighborhoods in very sparse areas of the point cloud. By design, this approach implies that **points may not all have the same number of neighbors**, depending on the local geometry and density. Our pipeline is capable of dealing with neighborhoods of uneven sizes, without resorting to artificial subsampling or oversampling strategies.\n",
    "\n",
    "Applying `KNN` will store the results in `neighbor_index` and `neighbor_distance` attributes. Missing neighbors are indicated as `-1` in `neighbor_index`.\n",
    "\n",
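    "The radius-constrained search can be sketched in a few lines of NumPy. The hypothetical `knn_radius` helper below is brute force (quadratic in the number of points, so toy-sized clouds only) and follows the conventions described above: neighbors are sorted by distance and missing ones are flagged with `-1` in the index. Setting the distance of missing neighbors to `-1` as well is our own assumption. This is not the FRNN-based implementation behind `KNN`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def knn_radius(pos, k, r_max):\n",
    "    # Brute-force pairwise Euclidean distances\n",
    "    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)\n",
    "    # A point is not its own neighbor\n",
    "    np.fill_diagonal(dist, np.inf)\n",
    "    # Indices of the k closest candidates, sorted by increasing distance\n",
    "    order = np.argsort(dist, axis=1)[:, :k]\n",
    "    neighbor_distance = np.take_along_axis(dist, order, axis=1)\n",
    "    # Candidates farther than r_max are flagged as missing with -1\n",
    "    neighbor_index = np.where(neighbor_distance <= r_max, order, -1)\n",
    "    neighbor_distance = np.where(neighbor_index >= 0, neighbor_distance, -1.0)\n",
    "    return neighbor_index, neighbor_distance\n",
    "\n",
    "pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])\n",
    "neighbor_index, neighbor_distance = knn_radius(pos, k=2, r_max=2)\n",
    "print(neighbor_index.tolist())                     # [[1, -1], [0, -1], [-1, -1]]\n",
    "print((neighbor_index >= 0).sum(axis=1).tolist())  # [1, 1, 0]\n",
    "```\n",
    "\n",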
    "The neighbors are used for two things in the preprocessing pipeline:\n",
    "- computing local geometric features with `PointFeatures`, later used by `CutPursuitPartition` as pointwise signal for the superpoint partition\n",
    "- computing the adjacency graph with `AdjacencyGraph`, later used by `CutPursuitPartition` as the graph on which the superpoint partition is computed\n",
    "\n",
    "> **Note 🤓**: Our fast `KNN` implementation internally relies on [`FRNN`](https://github.com/lxxue/FRNN), which is optimized for **GPU-based neighbor search**. While it offers considerable speedups compared to other off-the-shelf neighbor search libraries, its installation has proven challenging for some users. We might move to a slightly-slower-but-more-stable CPU-based [`nanoflann`](https://github.com/jlblancoc/nanoflann/tree/c4c4daf6bb9bda9890fb58324282016b4184d887) implementation in the future. If you are having trouble installing `FRNN`, check the [related solved issues](https://github.com/drprojects/superpoint_transformer/issues?q=frnn) in the repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913064a0-1e2e-42a2-9022-8fbb7621af27",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = KNN(k=25, r_max=2)(data)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e00cf568-1dda-4af5-b6f7-614b3c14c3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average number of missing neighbors (flagged with -1) per point\n",
    "(data.neighbor_index == -1).sum() / data.num_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f8880c0-54fb-4bde-9bf1-7a842f4f0961",
   "metadata": {},
   "source": [
    "Let's visualize the number of neighbors per point when `r_max` is smaller than the value set for DALES."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43bb4400-f98f-4ccb-a3cf-4a399ad6d038",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.num_neighbors = data.neighbor_index.ge(0).sum(dim=1)\n",
    "data.show(class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS, center=[485, 505, 0], radius=20, keys='num_neighbors')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac19a741-9283-4c0d-81c2-28fc6dee134e",
   "metadata": {},
   "source": [
    "#### 5.2.3. Elevation estimation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ae9bcdb-e806-4d39-afb3-62aa06b7d0b4",
   "metadata": {},
   "source": [
    "`GroundElevation` searches for the ground among the points and then infers each point's `elevation`. Indeed, elevation is a more informative feature than the `z` coordinate of points for semantic parsing. For real-life large point cloud acquisitions, the absolute `z` value usually carries no meaning, but the _relative `z`_ with respect to the ground does (the same holds for absolute `x` and `y` values).\n",
    "\n",
    "To find the ground, we simply use the [RANSAC](https://en.wikipedia.org/wiki/Random_sample_consensus) algorithm. `GroundElevation(threshold=..., scale=...)` will search for the ground as a planar surface located within `threshold` of the lowest point in the cloud. Pointwise distance to the plane will then be computed and normalized by `scale`. `threshold` should be tuned for environments where other large planar surfaces may affect the RANSAC ground search (eg ceiling, building roof, bridges, below-ground water surface, ...).\n",
    "\n",
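    "For intuition, here is a minimal RANSAC ground-plane sketch in NumPy. The hypothetical `ransac_ground_elevation` helper restricts the search to points within `threshold` of the lowest point and normalizes the signed point-to-plane distance by `scale`, as described above; the number of iterations `n_iter`, the inlier distance `inlier_dist` and the upward orientation of the normal are our own assumptions. `GroundElevation` is implemented differently and offers more options.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def ransac_ground_elevation(pos, threshold, scale, n_iter=100, inlier_dist=0.1, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Only points within `threshold` of the lowest point are ground candidates\n",
    "    candidates = pos[pos[:, 2] - pos[:, 2].min() < threshold]\n",
    "    best_inliers, best_plane = -1, None\n",
    "    for _ in range(n_iter):\n",
    "        # Plane through 3 random candidates\n",
    "        p0, p1, p2 = candidates[rng.choice(len(candidates), 3, replace=False)]\n",
    "        normal = np.cross(p1 - p0, p2 - p0)\n",
    "        norm = np.linalg.norm(normal)\n",
    "        if norm < 1e-12:\n",
    "            continue  # degenerate (collinear) sample\n",
    "        normal = normal / norm\n",
    "        if normal[2] < 0:\n",
    "            normal = -normal  # orient the plane normal upwards\n",
    "        d = -normal @ p0\n",
    "        # Count the candidates lying close to this plane\n",
    "        inliers = np.sum(np.abs(candidates @ normal + d) < inlier_dist)\n",
    "        if inliers > best_inliers:\n",
    "            best_inliers, best_plane = inliers, (normal, d)\n",
    "    normal, d = best_plane\n",
    "    # Signed distance to the ground plane, normalized by `scale`\n",
    "    return (pos @ normal + d) / scale\n",
    "\n",
    "# Flat 5 x 5 ground grid at z = 0, plus two points 10 units above it\n",
    "xx, yy = np.meshgrid(np.arange(5.0), np.arange(5.0))\n",
    "ground = np.stack([xx.ravel(), yy.ravel(), np.zeros(25)], axis=1)\n",
    "pos = np.vstack([ground, [[1.0, 1.0, 10.0], [2.0, 2.0, 10.0]]])\n",
    "elevation = ransac_ground_elevation(pos, threshold=5, scale=20)\n",
    "print(elevation[-2:])  # [0.5 0.5]\n",
    "```\n",
    "\n",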
    "> **Note 🤓**: Using RANSAC to represent the ground surface is a _coarse and error-prone_ strategy. While it was sufficient for the benchmark datasets used in our paper, more advanced tools should be used for capturing non-planar outdoor terrain or multi-floor indoor scans. See the `GroundElevation` documentation for more options for pre-filtering non-ground points and for more advanced ground surface models!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edf3b9a2-1109-4839-867a-0159c17b22cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = GroundElevation(threshold=5, scale=20)(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5495a2e1-2d0b-4ad2-8838-3d5e4a708f51",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "sns.displot(data.elevation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d344d65-009b-4db0-8e88-1bbe31748796",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show(class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS, keys='elevation')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8b34397-4f6e-438e-b572-bcaaa21d0840",
   "metadata": {},
   "source": [
    "For the relatively flat Vancouver dataset, the DALES parametrization of `GroundElevation` seems good enough."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "082b850e-136c-40f0-a524-69da84e4373d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### 5.2.4. Pointwise local geometric features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d8e7ef6-3b82-4066-b5be-f89e4cb30d7e",
   "metadata": {},
   "source": [
    "`PointFeatures` computes handcrafted features characterizing each point and its neighborhood. The following features are currently supported:\n",
    "- RGB color\n",
    "- HSV color\n",
    "- LAB color\n",
    "- density\n",
    "- linearity\n",
    "- planarity\n",
    "- scattering\n",
    "- verticality\n",
    "- normal\n",
    "- length\n",
    "- surface\n",
    "- volume\n",
    "- curvature\n",
    "\n",
    "These features should be computed with the superpoint partition in mind: they will be the **criteria based on which points will or will not be grouped together** by the cut-pursuit algorithm.\n",
    "\n",
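    "Linearity, planarity and scattering are commonly derived from the eigenvalues $\\lambda_1 \\geq \\lambda_2 \\geq \\lambda_3$ of the covariance matrix of a point's neighborhood. The sketch below uses the classic eigenvalue-based definitions; it is an illustration under our own assumptions, and the exact normalization used by `PointFeatures` (eg whether square roots of the eigenvalues are used) may differ. The `eigen_features` helper is hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def eigen_features(neighborhood):\n",
    "    # neighborhood: (k, 3) coordinates of a point and its neighbors\n",
    "    centered = neighborhood - neighborhood.mean(axis=0)\n",
    "    cov = centered.T @ centered / len(neighborhood)\n",
    "    # eigvalsh returns ascending eigenvalues, so this yields l1 >= l2 >= l3\n",
    "    l3, l2, l1 = np.clip(np.linalg.eigvalsh(cov), 0, None)\n",
    "    linearity = (l1 - l2) / l1\n",
    "    planarity = (l2 - l3) / l1\n",
    "    scattering = l3 / l1\n",
    "    return linearity, planarity, scattering\n",
    "\n",
    "line = np.array([[0.0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0]])\n",
    "xx, yy = np.meshgrid(np.arange(3.0), np.arange(3.0))\n",
    "plane = np.stack([xx.ravel(), yy.ravel(), np.zeros(9)], axis=1)\n",
    "print(np.round(eigen_features(line), 3))   # [1. 0. 0.]\n",
    "print(np.round(eigen_features(plane), 3))  # [0. 1. 0.]\n",
    "```\n",
    "\n",
    "By construction, the three values sum to one.\n",
    "\n",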
    "Which features are useful for your problem will depend on your classes of interest. For instance, when studying anthropic structures, planarity and linearity are very important. Note that the robustness and expressivity of these geometric features will depend on your `KNN` parametrization. If your point clouds come with RGB colors, converting those to the HSV or LAB colorspaces may help capture object boundaries (cf [SLIC](https://ieeexplore.ieee.org/document/6205760) paper).\n",
    "\n",
    "Interestingly, Vancouver has RGB colors, which was not the case for DALES. Let's see if using these instead of the LiDAR intensity improves the partition. \n",
    "\n",
    "> **Note 🤓**: `PointFeatures` supports various strategies for geometric feature computation. By default, all neighbors produced by `KNN` are used. One may also specify a minimum neighborhood size with `PointFeatures(k_min=...)`: points with fewer neighbors will receive `0` geometric features, to mitigate low-quality features for too-small neighborhoods. Besides, `PointFeatures(k_step=..., k_min_search=...)` will search for the optimal neighborhood size among available neighbors for each point, based on the entropy of the eigenfeatures (following this [paper](https://isprs-annals.copernicus.org/articles/II-3/181/2014/isprsannals-II-3-181-2014.pdf))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fdf7c5b-e041-417d-9b7f-37525e1b7d5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = PointFeatures(keys=('elevation', 'rgb', 'hsv', 'linearity', 'planarity', 'scattering', 'verticality'))(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6293bc89-2fdd-454c-ae82-34622361f07d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show(class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS, center=[485, 505, 0], radius=20, keys=data.keys)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38dbc78b-234c-491c-8351-063079f9f77d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### 5.2.5. Adjacency graph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3330050e-a83d-45cc-9b2d-9c2e0888fb00",
   "metadata": {},
   "source": [
    "`AdjacencyGraph` computes the adjacency graph on which the superpoint partition will be computed. It relies on the output of `KNN` to find neighbors for each point. `AdjacencyGraph(k=..., w=...)` will store the edges of the `k`-NN graph in `Data.edge_index`, along with edge weights in `Data.edge_attr` to be used in the partition (the larger an edge's weight, the harder it is to separate the corresponding points)."
   ]
  },
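  {
   "cell_type": "markdown",
   "id": "9c4e2d7a-5b1f-4a83-b6e0-3f2a1d8c7e54",
   "metadata": {},
   "source": [
    "As an illustration, here is a minimal sketch of turning the `KNN` output into an edge list. The hypothetical `adjacency_graph` helper keeps the `k` closest neighbors of each point, drops missing (`-1`) neighbors and assigns the same weight `w` to every edge. That constant weighting is our own simplifying assumption: the actual `AdjacencyGraph` transform may compute its weights differently, and may also symmetrize or deduplicate edges.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def adjacency_graph(neighbor_index, k, w):\n",
    "    # Keep the k closest neighbors of each point (neighbors are sorted by distance)\n",
    "    nn = neighbor_index[:, :k]\n",
    "    source = np.repeat(np.arange(nn.shape[0]), nn.shape[1])\n",
    "    target = nn.reshape(-1)\n",
    "    # Drop missing neighbors, flagged with -1\n",
    "    valid = target >= 0\n",
    "    edge_index = np.stack([source[valid], target[valid]])\n",
    "    # Same weight w on every edge (simplifying assumption)\n",
    "    edge_attr = np.full(edge_index.shape[1], float(w))\n",
    "    return edge_index, edge_attr\n",
    "\n",
    "neighbor_index = np.array([[1, -1], [0, 2], [1, -1]])\n",
    "edge_index, edge_attr = adjacency_graph(neighbor_index, k=2, w=1)\n",
    "print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]\n",
    "```"
   ]
  },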
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da751094-5530-4ec2-84fc-eb98332cfae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = AdjacencyGraph(k=10, w=1)(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc865c98-b2ae-40d2-b885-04428d00910b",
   "metadata": {},
   "source": [
    "#### 5.2.6. Hierarchical partition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebf411c2-3b29-4d2c-9eae-845c9354ac73",
   "metadata": {},
   "source": [
    "`CutPursuitPartition` is where the actual superpoint partition occurs. The [parallel cut-pursuit](https://arxiv.org/abs/1905.02316) algorithm is used to partition the adjacency graph based on point features. A regularization term rules the trade-off between \"many-superpoints-with-homogeneous-content\" and \"few-superpoints-with-heterogeneous-content\".\n",
    "\n",
    "In `CutPursuitPartition(regularization=..., spatial_weight=..., k_adjacency=..., cutoff=...)`, `regularization` carries a list of increasing float values for coarser and coarser hierarchical superpoint partition levels. `spatial_weight` indicates how much weight the point coordinates carry with respect to the point features when grouping points: the larger the weight, the more the spatial coordinates take over, and the more tessellated-looking the partition. `k_adjacency` prevents superpoints from staying isolated. `cutoff` sets the minimum number of points per superpoint at each partition level: smaller superpoints will be merged with other superpoints.\n",
    "\n",
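    "To build intuition for the role of `regularization`, the sketch below evaluates, for a given partition, a simplified version of the kind of energy that cut-pursuit minimizes: a fidelity term (squared distance between each point's features and the mean features of its superpoint) plus `regularization` times the total weight of the adjacency edges that are cut. The hypothetical `partition_energy` helper performs no optimization, and `spatial_weight`, `cutoff` and `k_adjacency` are not modeled.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def partition_energy(features, segment, edge_index, edge_weight, regularization):\n",
    "    # Fidelity: squared distance between each point's features and its segment mean\n",
    "    num_segments = segment.max() + 1\n",
    "    counts = np.bincount(segment, minlength=num_segments)\n",
    "    means = np.zeros((num_segments, features.shape[1]))\n",
    "    np.add.at(means, segment, features)\n",
    "    means /= counts[:, None]\n",
    "    fidelity = np.sum((features - means[segment]) ** 2)\n",
    "    # Contour: total weight of the edges whose endpoints lie in different segments\n",
    "    cut = segment[edge_index[0]] != segment[edge_index[1]]\n",
    "    contour = regularization * np.sum(edge_weight[cut])\n",
    "    return fidelity + contour\n",
    "\n",
    "features = np.array([[0.0], [0.0], [1.0], [1.0]])\n",
    "edge_index = np.array([[0, 1, 2], [1, 2, 3]])  # chain graph 0-1-2-3\n",
    "edge_weight = np.ones(3)\n",
    "two_parts = np.array([0, 0, 1, 1])\n",
    "one_part = np.array([0, 0, 0, 0])\n",
    "for reg in (0.1, 2.0):\n",
    "    e_two = partition_energy(features, two_parts, edge_index, edge_weight, reg)\n",
    "    e_one = partition_energy(features, one_part, edge_index, edge_weight, reg)\n",
    "    print(f'regularization={reg}: two superpoints {e_two:.2f}, one superpoint {e_one:.2f}')\n",
    "```\n",
    "\n",
    "With a small regularization, splitting the chain into two homogeneous superpoints has the lower energy; with a large one, merging everything into a single superpoint wins, which is why increasing `regularization` values yield coarser partition levels.\n",
    "\n",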
    "Before computing the partition, we need to copy all the features we want to use for the partition into the `x` attribute (`CutPursuitPartition` will blindly use whatever it finds in `x`). To this end, we will use the `AddKeysTo` transform.\n",
    "\n",
    "You can play with the features passed to `AddKeysTo` and with the `CutPursuitPartition` parameters, and see how they impact your partition metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7629bf6-457e-4960-8e07-945d70ba2097",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy desired features to `x`\n",
    "data = AddKeysTo(keys=['linearity', 'planarity', 'scattering', 'elevation'], to='x', delete_after=False)(data)\n",
    "\n",
    "# Compute the hierarchical partition\n",
    "nag = CutPursuitPartition(\n",
    "    regularization=[0.1, 0.2], \n",
    "    spatial_weight=[0.1, 0.01], \n",
    "    cutoff=[10, 30], \n",
    "    iterations=15, \n",
    "    k_adjacency=10)(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb6c5039-1918-4f48-aefe-61e66a4954de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ratio of sizes of successive partition levels\n",
    "nag.level_ratios"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af5b8a29-1499-49ea-8d24-58b47af09507",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Oracle semantic segmentation metrics\n",
    "nag[1].semantic_segmentation_oracle(VANCOUVER_NUM_CLASSES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42f21fd3-5f97-453e-a5d7-14cc399432cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "nag.show(class_names=VANCOUVER_CLASS_NAMES, class_colors=VANCOUVER_CLASS_COLORS, center=[485, 505, 0], radius=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6515c77a-873f-486c-9888-2bbb916e3864",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### 5.2.7. Superpoint-wise handcrafted features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "811819b6-8af9-4caf-8675-ebe123f16ad3",
   "metadata": {},
   "source": [
    "Once the hierarchical partition has been computed, `SegmentFeatures` builds some superpoint-wise features at each partition level. These are basic descriptors that can be used to help the model characterize the superpoints or the connection between superpoints (see `RadiusHorizontalGraph`)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a2e346-57be-4d81-afe7-93698f9e9627",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "#### 5.2.8. Superpoint adjacency graph and features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08e92fc6-cb0c-4271-8cef-b4773aaeee87",
   "metadata": {},
   "source": [
    "`RadiusHorizontalGraph` computes the superpoint adjacency graphs and stores them in the `edge_index` and `edge_attr` of each partition level. These are the graphs used by SPT to propagate information between nodes with self-attention. \n",
    "\n",
    "In particular, the `gap`, `k_min` and `k_max` parameters of `RadiusHorizontalGraph(gap=..., k_min=..., k_max=...)` rule how far each superpoint is allowed to look inside each partition level. You can think of this as the **\"kernel size\" of the attention mechanism**. While increasing `gap` may improve semantic segmentation performance, be aware that it will also increase the number of edges in the adjacency graph, which directly impacts computation and memory efficiency.\n",
    "\n",
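    "To get a feel for the `gap` parameter, here is a brute-force sketch under our own simplified reading of it: two superpoints are connected when their closest points are at most `gap` apart, and each superpoint keeps at most `k_max` of its closest such neighbors. The `superpoint_adjacency` helper is hypothetical; `k_min` is not modeled, and the actual `RadiusHorizontalGraph` implementation is more efficient and also computes edge features.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def superpoint_adjacency(pos, segment, gap, k_max):\n",
    "    num_segments = segment.max() + 1\n",
    "    # Brute-force point-to-point distances (toy-sized clouds only)\n",
    "    dist = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)\n",
    "    # Distance between the closest points of every pair of superpoints\n",
    "    seg_dist = np.full((num_segments, num_segments), np.inf)\n",
    "    for a in range(num_segments):\n",
    "        for b in range(num_segments):\n",
    "            if a != b:\n",
    "                seg_dist[a, b] = dist[np.ix_(segment == a, segment == b)].min()\n",
    "    edges = []\n",
    "    for a in range(num_segments):\n",
    "        # Superpoints within `gap`, closest first, capped at k_max neighbors\n",
    "        close = [b for b in np.argsort(seg_dist[a]) if seg_dist[a, b] <= gap][:k_max]\n",
    "        edges.extend((a, int(b)) for b in close)\n",
    "    return np.array(edges, dtype=np.int64).reshape(-1, 2).T\n",
    "\n",
    "pos = np.array([[0.0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [10, 0, 0]])\n",
    "segment = np.array([0, 0, 1, 1, 2])\n",
    "print(superpoint_adjacency(pos, segment, gap=1.5, k_max=2).tolist())  # [[0, 1], [1, 0]]\n",
    "```\n",
    "\n",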
    "Once we are happy with our partition parametrization, we will want to deploy the preprocessing to our entire dataset and train on it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7137bc78-7b0b-4c13-a63e-b9ef1a7393a4",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 6. Training on your own `Dataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a3b5ee7-3b9b-4c56-aaf5-3967bc19f7d9",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 6.1. Creating your own `Dataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed6558b4-36df-4889-b423-ce5aa6af2280",
   "metadata": {},
   "source": [
    "To make the most of the codebase capabilities, your dataset must inherit from the `BaseDataset` class and follow a certain structure. See the [datasets documentation](../docs/datasets.md) for how to implement your own dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29aa9194-3fc8-45db-8350-fa33bd9a1b27",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 6.2. Parametrizing your transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be647680-8be0-463a-a65e-b0f60dfe8772",
   "metadata": {},
   "source": [
    "We have seen above how to configure your `pre_transform`. Since these will be executed once on your dataset at preprocessing time, you will not need to re-run them at each experiment if you leave the parameters unchanged.\n",
    "\n",
    "Still, you will also need to parametrize your `on_device_train_transform` and `on_device_val_transform` (usually we fix `on_device_test_transform=on_device_val_transform`).\n",
    "Exploring these is outside of the scope of this tutorial, but several design choices can have an impact on your model performance and memory consumption.\n",
    "\n",
    "Have a look at how we configured existing datasets in `configs/datamodule/semantic/` for reference. Besides, the source code of all transforms is fairly well documented. Make sure you read it to understand what they do!\n",
    "\n",
    "> **Tips 💡**:\n",
    "> - It is possible to parametrize SPT to **train and infer on an 11G GPU 💾**. To this end, you can have a look at the existing '*_11G' configs in `configs/experiment/semantic` to see how we did for supported datasets.\n",
    "> - We provide a **detailed list of suggestions for troubleshooting CUDA memory errors** in the [README](https://github.com/drprojects/superpoint_transformer/tree/master?tab=readme-ov-file#cuda-out-of-memory-errors)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "118a9a17-6425-49a4-9c5f-df60104fb98a",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### 6.3. Training and testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2794e208-9a67-461e-925e-b23f4ef49936",
   "metadata": {},
   "source": [
    "Have a look at the [README](https://github.com/drprojects/superpoint_transformer/tree/master?tab=readme-ov-file) for basic training and testing commands.\n",
    "Refer to the documentation of the [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template) to make the most of its _**many available functionalities**_!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:spt] *",
   "language": "python",
   "name": "conda-env-spt-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
